Add SANA-WM camera-controlled image-to-video pipeline #13881

Open
lawrence-cj wants to merge 7 commits into huggingface:main from lawrence-cj:feat/sana-wm-diffusers-cleanup

@lawrence-cj lawrence-cj commented Jun 7, 2026

What does this PR do?

Hi @sayakpaul @dg845, long time no see. Hope you are doing great. ♥️

Adds SANA-WM, the camera-controlled image-to-video world model from NVIDIA + MIT HAN Lab, as a first-class diffusers pipeline and transformer. Given a first-frame image, a text prompt, and a camera trajectory (explicit c2w poses or a WASD/IJKL action-DSL string), the pipeline generates a video whose motion follows the requested camera path. The model is trained natively for minute-scale generation at 704×1280.

The pipeline runs in two stages:

  1. Stage 1 — SanaWMTransformer3DModel. A 1.6B-parameter bidirectional DiT with GDN-Triton linear attention and a UCPE camera-control branch; samples with an LTX-style flow-matching Euler scheduler at per-token timesteps. The first latent frame is the conditioning anchor.
  2. Stage 2 — SanaWMLTX2Refiner (optional). A chunk-causal AR refiner that wraps diffusers' LTX2VideoTransformer3DModel + LTX2TextConnectors + Gemma-3 text encoder. Processes 3 latent frames at a time with a sliding window of [source_sink + recent_history + active_block] K/V, so per-block compute is bounded and total refinement cost is linear in video length.

Both stages decode through AutoencoderKLLTX2Video.

Layout

src/diffusers/
├── models/transformers/
│   ├── transformer_sana_wm.py          # SanaWMTransformer3DModel + blocks + helpers
│   └── transformer_sana_wm_kernels.py  # fused Triton kernels + camera math
└── pipelines/sana_wm/
    ├── __init__.py
    ├── pipeline_sana_wm.py             # SanaWMPipeline
    ├── pipeline_output.py              # SanaWMPipelineOutput
    ├── refiner.py                      # SanaWMLTX2Refiner + RefinerChunkRunner
    └── cam_utils.py                    # action DSL, intrinsics, resize+crop, Plücker/raymap

scripts/sana_wm/convert_sana_wm_to_diffusers.py
docs/source/en/api/{pipelines/sana_wm.md, models/sana_wm_transformer3d.md}

Usage

import torch
from PIL import Image
from diffusers import SanaWMPipeline
from diffusers.utils import export_to_video

pipe = SanaWMPipeline.from_pretrained(
    "Efficient-Large-Model/SANA-WM_bidirectional-diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.vae.to(torch.float32)
pipe.enable_model_cpu_offload()

out = pipe(
    image=Image.open("input.png").convert("RGB"),
    prompt="A car driving across a vast desert plain at golden hour.",
    action="w-80,jw-40,w-40",                    # WASD-style action DSL
    intrinsics=[800.0, 800.0, 845.0, 464.0],      # fx, fy, cx, cy in original-image pixels
    num_frames=161,
    num_inference_steps=60,
)
export_to_video(list(out.frames), "sana_wm.mp4", fps=16)
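
The `action` string in the usage above reads as comma-separated `keys-count` segments. As a minimal sketch, assuming each segment means "hold these keys for this many frames" (an interpretation suggested by 80 + 40 + 40 = 160, one short of `num_frames=161`, but not confirmed anywhere in this PR), a parser could look like the following. `parse_action` is a hypothetical name used for illustration; it is not the `cam_utils` API.

```python
from typing import List, Tuple


def parse_action(action: str) -> List[Tuple[str, int]]:
    """Split an action string such as "w-80,jw-40" into (keys, count) pairs.

    Hypothetical helper: the real DSL grammar lives in cam_utils.py and may differ.
    """
    segments = []
    for part in action.split(","):
        # rpartition splits on the last "-", so "jw-40" -> ("jw", "-", "40").
        keys, sep, count = part.strip().rpartition("-")
        if not sep or not keys or not count.isdigit():
            raise ValueError(f"malformed action segment: {part!r}")
        segments.append((keys, int(count)))
    return segments


segments = parse_action("w-80,jw-40,w-40")
total_steps = sum(count for _, count in segments)
print(segments, total_steps)
```

Under this reading, a segment such as `jw-40` combines two keys for 40 steps; how each key maps to a camera motion is defined by the pipeline, not by this sketch.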

Demo

5-second sample (30 stage-1 steps + 3-step distilled AR refiner, official asset/sana_wm/demo_0 inputs, 704×1280 @ 16 fps):

sana_wm_5s.mp4

Smoke tests

End-to-end on 1× H100 80GB with `enable_model_cpu_offload` and the official `asset/sana_wm/demo_0.{png,txt,_pose.npy,_intrinsics.npy}`:

| Duration | Frames | Stage-1 (30 steps) | Refiner (AR, 3 blocks) | Output |
|---|---|---|---|---|
| 5s | 80 | 1:11 | 5:24 / step | 525 KB |
| 10s | 160 | 1:11 | 28:55 (7 blocks) | 1.4 MB |
| 20s | 320 | 1:57 | ≈ 4 min / block (14) | 3.2 MB |
| 50s | 800 | 5:33 | 30:46 (34 blocks) | 6.3 MB |

Checkpoint conversion

scripts/sana_wm/convert_sana_wm_to_diffusers.py --src Efficient-Large-Model/SANA-WM_bidirectional --dst /local/path converts the public release into a `from_pretrained`-loadable directory (VAE, Gemma-2 tokenizer + text_encoder, transformer, scheduler, refiner subfolders, top-level `model_index.json`).

Related

Paper: https://arxiv.org/abs/2605.15178

HaoyiZhu and others added 4 commits June 1, 2026 01:28
…line

Adds the public SANA-WM bidirectional camera-controlled image-to-video
model as a first-class diffusers pipeline + transformer. Layout mirrors
``sana_video``: the model lives under ``src/diffusers/models/transformers/``
as a near-single-file (kernels split off so the ``@triton.jit`` decorators
don't drown the model body); the pipeline lives under
``src/diffusers/pipelines/sana_wm/``.

Files added:

  src/diffusers/models/transformers/
  ├── transformer_sana_wm.py         # SanaWMTransformer3DModel + blocks + helpers
  └── transformer_sana_wm_kernels.py # fused Triton kernels + camera math

  src/diffusers/pipelines/sana_wm/
  ├── __init__.py
  ├── pipeline_sana_wm.py
  ├── pipeline_output.py
  ├── refiner.py
  └── cam_utils.py

Pipeline architecture:
* Stage 1: 1600M ``SanaWMTransformer3DModel`` DiT with bidirectional
  GDN-Triton linear attention + UCPE camera-control branch, LTX-style
  flow-matching Euler scheduler with per-token timesteps.
* Stage 2: LTX-2 sink-bidirectional Euler refiner (3 distilled sigma
  steps, reuses diffusers' ``LTX2VideoTransformer3DModel`` +
  ``LTX2TextConnectors`` + Gemma-3 text encoder).
* Decode through the LTX-2 VAE (``AutoencoderKLLTX2Video``).

One-line usage:

  pipe = SanaWMPipeline.from_pretrained(
      "Efficient-Large-Model/SANA-WM_bidirectional-diffusers",
      torch_dtype=torch.bfloat16,
  ).to("cuda")
  out = pipe(image=img, prompt="...", action="w-80,jw-40,w-40",
             intrinsics=[fx, fy, cx, cy])

End-to-end smoke test (stage-1 + refiner + VAE decode) passes on H100.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…xport

transformer_sana_wm.py:
* License header switched to the "HuggingFace Team and SANA-WM Authors"
  style used by merged sana_video.
* Imports rewritten in stdlib -> third-party -> diffusers order; use
  diffusers `from ...utils import logging` instead of stdlib `logging`.
* Fix 9 `Optional[X]` annotations written as `X or None` (Python's `or`
  short-circuits and silently returns `X`).
* Fix two `assert (cond, msg)` tuple-asserts in PatchEmbedMS3D.forward
  that always pass (SyntaxWarning at import time).
* Remove duplicate `__all__` declarations (the second silently overwrote
  the first).
* Remove dead `reset_bn` (imports a nonexistent `packages.apps.utils`,
  would crash on call).
* Remove the duplicate `logger = logging.getLogger(__name__)` further
  down in the file.
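
The two annotation and assertion pitfalls listed above are plain Python behavior and can be reproduced in isolation. The snippet below is a standalone illustration; the variable names are made up and none of it is taken from `transformer_sana_wm.py`.

```python
from typing import Optional

# `X or None` is evaluated as a boolean expression: a class object is truthy,
# so `or` short-circuits and the annotation silently becomes just `int`.
broken_annotation = int or None
fixed_annotation = Optional[int]

# `assert (cond, msg)` asserts a two-element tuple. A non-empty tuple is always
# truthy, so the check can never fail even when `cond` is False.
tuple_condition = (False, "expected a 5D input")

print(broken_annotation)      # the bare class, with no Optional wrapper
print(fixed_annotation)
print(bool(tuple_condition))  # truthiness of the tuple, not of the condition
```

The fix for the second case is to drop the parentheses, `assert cond, msg`, so that only `cond` is tested.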

transformer_sana_wm_kernels.py:
* License header normalized; collapse three duplicate triton/torch import
  blocks into one.

pipeline_sana_wm.py:
* License header normalized.
* `_decode_latents` now returns `(T, H, W, 3)` float in [0, 1], matching
  the diffusers convention used by `VideoProcessor`. Returning uint8
  silently broke `export_to_video`: it does `frame * 255` assuming float
  input, so uint8 overflows to `(-x) mod 256` and inverts colors.
* `__call__` converts to PIL/uint8 only when `output_type="pil"`.
* Intrinsics argument now accepts (4,), (F, 4), (3, 3), and (F, 3, 3)
  forms (auto-extracts fx, fy, cx, cy from a 3x3 K) and auto-trims to
  `num_frames` when a longer-than-needed trajectory is passed.
* Inline `retrieve_timesteps` with the standard `# Copied from
  diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps`
  marker, matching merged sana_video.
* Docstrings + EXAMPLE_DOC_STRING updated to reflect the new return type.
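
The intrinsics handling described above can be sketched in pure Python as follows. This is an illustration under stated assumptions, not the pipeline's code: `normalize_intrinsics` is a hypothetical name, a single set of intrinsics is assumed to be shared by every frame, and a too-short trajectory is assumed to be an error.

```python
from typing import List


def _is_scalar(x) -> bool:
    return isinstance(x, (int, float))


def _from_matrix(K) -> List[float]:
    # Pinhole matrix K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
    return [float(K[0][0]), float(K[1][1]), float(K[0][2]), float(K[1][2])]


def normalize_intrinsics(intrinsics, num_frames: int) -> List[List[float]]:
    """Return num_frames rows of [fx, fy, cx, cy] from any of the four layouts."""
    if _is_scalar(intrinsics[0]):                                  # (4,)
        rows = [[float(v) for v in intrinsics]]
    elif _is_scalar(intrinsics[0][0]):
        if len(intrinsics) == 3 and all(len(r) == 3 for r in intrinsics):
            rows = [_from_matrix(intrinsics)]                      # (3, 3)
        else:
            rows = [[float(v) for v in r] for r in intrinsics]     # (F, 4)
    else:
        rows = [_from_matrix(K) for K in intrinsics]               # (F, 3, 3)

    if any(len(r) != 4 for r in rows):
        raise ValueError("each frame needs exactly fx, fy, cx, cy")
    if len(rows) == 1:
        # Assumption: one set of intrinsics applies to every frame.
        rows = [list(rows[0]) for _ in range(num_frames)]
    if len(rows) < num_frames:
        raise ValueError("trajectory is shorter than num_frames")
    return rows[:num_frames]  # trim a longer-than-needed trajectory


K = [[800.0, 0.0, 845.0], [0.0, 800.0, 464.0], [0.0, 0.0, 1.0]]
print(normalize_intrinsics(K, num_frames=2))
```

A `(3, 4)` input is treated as three per-frame rows rather than a matrix, because the 3x3 check requires every row to have length 3.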

pipeline_output.py:
* Update `frames` field docstring to describe the new float [0, 1] return.
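
The float [0, 1] convention matters because multiplying an 8-bit value by 255 wraps around. Since 255 ≡ -1 (mod 256), `x * 255` in uint8 arithmetic equals `(-x) mod 256`, which is the color inversion described in the `_decode_latents` note. The sketch below models the 8-bit wraparound with plain Python integers; it is a worked example, not the `export_to_video` implementation.

```python
def scale_as_uint8(pixel: int) -> int:
    """Model `pixel * 255` carried out in 8-bit unsigned arithmetic."""
    return (pixel * 255) % 256


def scale_as_float(pixel: float) -> int:
    """Intended path: a float in [0, 1] scaled to the 0..255 range."""
    return int(round(pixel * 255))


for value in (0, 1, 200, 255):
    print(value, "->", scale_as_uint8(value))

print(1.0, "->", scale_as_float(1.0))
```

Bright input pixels come out dark and vice versa, while float frames scale to the expected range.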

refiner.py, cam_utils.py, scripts/sana_wm/convert_sana_wm_to_diffusers.py:
* License headers normalized.

Docs:
* New `docs/source/en/api/pipelines/sana_wm.md` and
  `docs/source/en/api/models/sana_wm_transformer3d.md`, modeled on
  sana_video.md / sana_video_transformer3d.md, wired into
  `docs/source/en/_toctree.yml` under Models and Pipelines.

5s end-to-end smoke test (81 frames @ 16fps, 30 stage-1 steps + 3-step
LTX-2 refiner) passes on 1x H100 80GB with `enable_model_cpu_offload`.
Round-trip diff vs raw float frames is 2.06/255 mean (h264 lossy noise),
confirming the export_to_video fix.
…+ KV cache hooks)

The first cleanup pass only kept the legacy single-shot refiner path. That
path is what the model was *not* trained on — its docstring even says
"feeding the full sequence at once is out-of-distribution" — and its cost
is O(T^2) attention over the full latent volume, which made longer videos
unusable (~21 min per refiner step at 321 frames on an H100).

Port the chunk-causal AR mode from the upstream reference so the refiner
matches the training contract:

* `refine_latents` now defaults to `block_size=3, kv_max_frames=11`
  (the canonical AR recipe). Pass `block_size=None` to fall back to the
  legacy single-shot path.
* New `_refine_latents_ar` + `_RefinerChunkRunner` orchestrate the sliding
  window: pre-capture pre-RoPE sink K/V on `z_sana[:source_sink_frames]`
  at sigma=0, then for each `block_size`-frame chunk run a 3-step Euler
  with prefix `{sink_k_pre, sink_v, sink_pe, history_k, history_v}` and
  capture post-RoPE K/V to feed the next window. History is bounded to
  `kv_max_frames - source_sink_frames` so per-block compute is constant.
* New `_predict_x0_active_block` runs the transformer on the active block
  only (Q from active, K/V from prefix+active).
* New `_capture_block_kv` runs sigma=0 forward with a pre_rope/post_rope
  capture flag set on each `attn1`.
* New `_forward_video_only_with_rope` takes a pre-built RoPE so each block
  can use absolute frame positions in the source video.
* `_streaming_self_attention` extended with the `_kv_cache_capture`,
  `_tf_capture_kv`, `_tf_kv_prefix` hook contract that AR mode uses to
  inject and capture K/V on each block.
* New helpers: `_build_rotary_emb_for_absolute_positions`,
  `_set_kv_prefix_on_blocks`, `_clear_kv_prefix_on_blocks`,
  `_set_capture_flag_on_blocks`, `_collect_captured_kv_from_blocks`.
* `_encode_prompt` now also moves the Gemma-3 text encoder back to CPU
  after producing the embeds — otherwise it stays resident through the
  entire AR loop and gates how much GPU memory the refiner transformer
  has left.

Module-level docstring updated to document both modes; existing
single-shot path preserved verbatim.
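
The sliding-window bookkeeping described in this commit can be sketched as a schedule of frame indices. This is a simplified illustration: `plan_ar_windows` is a hypothetical name, the blocks are assumed to tile the latent frames starting at index 0, and `source_sink_frames=2` is a made-up value, since the PR does not state the actual one.

```python
from typing import Iterator, List, Tuple

Window = Tuple[List[int], List[int], List[int]]


def plan_ar_windows(
    num_latent_frames: int,
    source_sink_frames: int,
    block_size: int = 3,
    kv_max_frames: int = 11,
) -> Iterator[Window]:
    """Yield (sink, history, active) latent-frame indices for each AR block."""
    if kv_max_frames <= source_sink_frames:
        raise ValueError("kv_max_frames must exceed source_sink_frames")
    history_cap = kv_max_frames - source_sink_frames
    sink = list(range(source_sink_frames))
    history: List[int] = []
    for start in range(0, num_latent_frames, block_size):
        active = list(range(start, min(start + block_size, num_latent_frames)))
        yield sink, list(history), active
        # Keep only the most recent frames so the K/V prefix stays bounded.
        history = (history + active)[-history_cap:]


windows = list(plan_ar_windows(num_latent_frames=21, source_sink_frames=2))
for sink, history, active in windows:
    print(f"sink={sink} history={history} active={active}")
```

Because the prefix never exceeds `kv_max_frames` frames, each block attends over at most `kv_max_frames + block_size` frames regardless of video length, which is why total refinement cost grows linearly with the number of blocks.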
…eemption)

The AR refiner is expensive (~3-5 min per block), and the refinement loop
keeps no in-progress state when run end-to-end, so a SLURM preemption
mid-refinement loses all progress. With the canonical
``block_size=3, kv_max_frames=11`` setup, refining a 50s video takes 34
blocks, all of which must complete without preemption on a
backfill queue.

Add per-block atomic checkpointing:

* ``SanaWMLTX2Refiner.refine_latents(checkpoint_dir=Path)`` and
  ``_refine_latents_ar`` accept a directory. After each completed AR
  block, the AR loop writes ``checkpoint_dir/state.pt`` atomically
  (tmp + os.replace).
* The payload is ``{block_idx_done, n_blocks, sink_size, block_size,
  output_shape, output, runner_state}``. ``runner_state`` is a CPU snapshot
  of the runner's ``_sink_kv_pre``, ``_history_kv_post``,
  ``_history_frames`` and ``torch.Generator`` state.
* On entry, if ``state.pt`` exists with a compatible shape signature, the
  AR loop loads the persisted output tensor + runner state and resumes
  from ``block_idx_done + 1`` instead of recomputing from scratch.
* ``SanaWMPipeline.__call__(refiner_checkpoint_dir=...)`` plumbs the
  directory through to the refiner.
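
The tmp + `os.replace` pattern and the resume check described above can be sketched with the standard library alone. This is a minimal illustration, not the refiner's code: `pickle` stands in for `torch.save`, and `save_state_atomic` and `resume_from` are hypothetical names.

```python
import os
import pickle
import tempfile
from pathlib import Path
from typing import Optional, Tuple


def save_state_atomic(checkpoint_dir: Path, state: dict) -> Path:
    """Write state.pt so readers only ever see a complete file."""
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    final_path = checkpoint_dir / "state.pt"
    # The temp file must be in the same directory for the rename to be atomic.
    fd, tmp_name = tempfile.mkstemp(dir=checkpoint_dir, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(state, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_name, final_path)  # atomic swap over any older state.pt
    except BaseException:
        if os.path.exists(tmp_name):
            os.unlink(tmp_name)
        raise
    return final_path


def resume_from(checkpoint_dir: Path, signature: dict) -> Tuple[int, Optional[dict]]:
    """Return (first block to run, loaded state or None)."""
    path = checkpoint_dir / "state.pt"
    if not path.exists():
        return 0, None
    with open(path, "rb") as f:
        state = pickle.load(f)
    # Ignore checkpoints written for a different run configuration.
    if any(state.get(key) != value for key, value in signature.items()):
        return 0, None
    return state["block_idx_done"] + 1, state
```

A preemption during the write leaves either the previous `state.pt` or none at all, so the loop restarts from the last fully completed block.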

Checkpoint size: ~output_volume + sink_KV (~360MB for 50 layers) +
rolling history KV (~3-4GB at full capacity) — saved once per block,
total per-block save overhead ~10s on lustre.
github-actions bot added the size/L, documentation, models, and pipelines labels and removed the size/L label on Jun 7, 2026
github-actions bot added the size/L label on Jun 9, 2026
* CPU unit tests for cam_utils helpers (action DSL → c2w, intrinsics
  rescale-for-crop, resize+center-crop, snap_num_frames 8k+1 rounding).
* Public-surface registration tests (top-level diffusers symbols,
  SanaWMPipelineOutput dataclass shape, refiner signature has AR defaults
  + checkpoint_dir, pipeline __call__ accepts c2w/action/intrinsics/
  refiner_checkpoint_dir).
* @slow @require_torch_accelerator integration stub for an end-to-end I2V
  against the public checkpoint, currently @unittest.skip — wires up the
  nightly GPU path without exploding regular CI.

SanaWMTransformer3DModel has hardcoded depth/hidden_size/num_heads inside
its inner SanaMSVideoCamCtrl (not exposed through register_to_config), so
the usual PipelineTesterMixin small-config fast tests aren't applicable
without a transformer refactor (follow-up PR).
github-actions bot added the tests label on Jun 9, 2026
@dg845 dg845 requested review from dg845 and yiyixuxu June 12, 2026 03:54
Labels

documentation Improvements or additions to documentation models pipelines size/L PR with diff > 200 LOC tests